{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import gzip, os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import keras\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense\n",
    "from keras import layers\n",
    "from keras.optimizers import Adam\n",
    "from keras.utils import to_categorical\n",
    "%matplotlib inline\n",
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the training and test data from MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60000, 28, 28, 1)\n",
      "(60000,)\n"
     ]
    }
   ],
   "source": [
    "def load_data(filename, num_images=60000):\n",
    "    \"\"\"Load images from a gzipped IDX3 (MNIST) file as mean-centred float32.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the gzip-compressed IDX3 image file.\n",
    "        num_images: how many images to read from the start of the file.\n",
    "\n",
    "    Returns:\n",
    "        float32 array of shape (num_images, rows, cols, 1); the per-pixel mean\n",
    "        over the loaded images has been subtracted.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the header is not an IDX3 image header or the file\n",
    "            holds fewer than num_images images.\n",
    "    \"\"\"\n",
    "    with gzip.open(filename) as bytestream:\n",
    "        # IDX3 header: four big-endian int32s (magic, image count, rows, cols).\n",
    "        raw_header = bytestream.read(16)\n",
    "        if len(raw_header) != 16:\n",
    "            raise ValueError('{}: truncated IDX3 header'.format(filename))\n",
    "        magic, count, rows, cols = (int(v) for v in np.frombuffer(raw_header, dtype='>i4'))\n",
    "        if magic != 2051:\n",
    "            raise ValueError('{}: not an IDX3 image file (magic {})'.format(filename, magic))\n",
    "        if num_images > count:\n",
    "            raise ValueError('{}: asked for {} images but the file holds {}'.format(\n",
    "                filename, num_images, count))\n",
    "        buf = bytestream.read(rows * cols * num_images)\n",
    "    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)\n",
    "    data = data.reshape(num_images, rows, cols, 1)\n",
    "    # Centre every pixel position on its mean over the loaded images.\n",
    "    # NOTE(review): each call uses its own mean, so the test set is centred with\n",
    "    # test statistics rather than the training mean -- confirm this is intended.\n",
    "    data -= np.mean(data, axis=0)\n",
    "    return data\n",
    "\n",
    "def load_label(filename, num_images=60000):\n",
    "    \"\"\"Load digit labels from a gzipped IDX1 (MNIST) label file.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the gzip-compressed IDX1 label file.\n",
    "        num_images: how many labels to read from the start of the file.\n",
    "\n",
    "    Returns:\n",
    "        int64 array of shape (num_images,) holding the label bytes.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the header is not an IDX1 label header or the file\n",
    "            holds fewer than num_images labels.\n",
    "    \"\"\"\n",
    "    with gzip.open(filename) as bytestream:\n",
    "        # IDX1 header: two big-endian int32s (magic, label count).\n",
    "        raw_header = bytestream.read(8)\n",
    "        if len(raw_header) != 8:\n",
    "            raise ValueError('{}: truncated IDX1 header'.format(filename))\n",
    "        magic, count = (int(v) for v in np.frombuffer(raw_header, dtype='>i4'))\n",
    "        if magic != 2049:\n",
    "            raise ValueError('{}: not an IDX1 label file (magic {})'.format(filename, magic))\n",
    "        if num_images > count:\n",
    "            raise ValueError('{}: asked for {} labels but the file holds {}'.format(\n",
    "                filename, num_images, count))\n",
    "        buf = bytestream.read(num_images)\n",
    "    if len(buf) != num_images:\n",
    "        # A short read would otherwise silently return too few labels.\n",
    "        raise ValueError('{}: label data is truncated'.format(filename))\n",
    "    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)\n",
    "    return labels\n",
    "\n",
    "# Training split: 60000 images and their labels (the loaders' default count)\n",
    "train_images_path = os.path.join('data', 'train-images-idx3-ubyte.gz')\n",
    "data = load_data(train_images_path)\n",
    "print(data.shape)\n",
    "\n",
    "train_labels_path = os.path.join('data', 'train-labels-idx1-ubyte.gz')\n",
    "labels = load_label(filename=train_labels_path)\n",
    "print(labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 28, 28, 1)\n",
      "(10000,)\n"
     ]
    }
   ],
   "source": [
    "# Held-out test split: the t10k files, read with an explicit count of 10000\n",
    "NUM_TEST_IMAGES = 10000\n",
    "test_images_path = os.path.join('data', 't10k-images-idx3-ubyte.gz')\n",
    "test_labels_path = os.path.join('data', 't10k-labels-idx1-ubyte.gz')\n",
    "\n",
    "test_data = load_data(test_images_path, num_images=NUM_TEST_IMAGES)\n",
    "print(test_data.shape)\n",
    "test_labels = load_label(test_labels_path, num_images=NUM_TEST_IMAGES)\n",
    "print(test_labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split the training set into training and validation sets\n",
    "Shuffle the indices, then use a random 80% of the samples for training and the remaining 20% for validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(48000, 28, 28, 1) (12000, 28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "# One-hot encode the digit labels; the width is pinned to the 10 digit classes\n",
    "# so it cannot shrink if the highest digit is absent from the loaded labels.\n",
    "encoded_labels = to_categorical(labels, num_classes=10)\n",
    "\n",
    "# A seeded, private RandomState gives the same split on every run without\n",
    "# touching the global NumPy RNG.\n",
    "SPLIT_SEED = 0\n",
    "shuffled_indices = np.random.RandomState(SPLIT_SEED).permutation(data.shape[0])\n",
    "\n",
    "# 80% of the samples for training, the remaining 20% for validation\n",
    "num_trains = int(.8 * data.shape[0])\n",
    "train_indices = shuffled_indices[:num_trains]\n",
    "val_indices = shuffled_indices[num_trains:]\n",
    "\n",
    "# Index the untouched `data` / `encoded_labels` instead of overwriting them:\n",
    "# the old in-place shuffle of `data` misaligned images and labels whenever\n",
    "# this cell was run a second time.\n",
    "train_data = data[train_indices]\n",
    "train_labels = encoded_labels[train_indices]\n",
    "\n",
    "val_data = data[val_indices]\n",
    "val_labels = encoded_labels[val_indices]\n",
    "\n",
    "print(train_data.shape, val_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.text.Text at 0x23534ca3748>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE2ZJREFUeJzt3X2MXNV5BvDnmdldez/8DayN7WAQDo5DVIO2FAlK3dJQ\nsKraqCkCodSVUJ0/CC1pqpbSSuGPSkFVA6IpRVkCjYkIaVSCQMIlsS0QJQTCmrrGfAVwbLBjbFyD\nvfZ+z7z9Yy7RBva+Z9h7Z+6sz/OTVrs779y5x7P7+M7Oe889NDOISHxKRQ9ARIqh8ItESuEXiZTC\nLxIphV8kUgq/SKQUfpFIKfwyJZJPkRwheSL5eL3oMUm+FH7xfNnMepKP84oejORL4ReJlMIvnq+T\nPELyJyTXFj0YyRd1br9MheRvAXgFwBiAawH8K4A1ZvZWoQOT3Cj8UheSTwB43My+WfRYJB962S/1\nMgAsehCSH4VfPobkfJJ/QHI2yTaS1wO4DMATRY9N8tNW9ACkJbUD+EcAqwBUALwGYIOZ/bzQUUmu\n9De/SKT0sl8kUgq/SKQUfpFIKfwikWrqu/1tnd3WPm9hM3cpEpXxY0cxMXyyrvMxMoWf5JUA7gJQ\nBvBtM7vdu3/7vIU49/q/yrJLEXG8+eAddd932i/7SZYB3A3gKgCrAVxHcvV0H09EmivL3/wXAXjT\nzPaY2RiA7wNYn8+wRKTRsoR/KYB3Jn2/P7nt15DcRHKA5EBl6GSG3YlInhr+br+Z9ZtZn5n1lbu6\nG707EalTlvAfALB80vfLkttEZAbIEv4XAKwkeTbJDtQu+PBYPsMSkUabdqvPzCZIfhnAj1Br9d1v\nZi/nNjLJRyvP29LVAQqVqc9vZlsAbMlpLCLSRDq9VyRSCr9IpBR+kUgp/CKRUvhFIqXwi0RKV+9t\nhiJ77Rn3zQLHblnPA2jkeQSnwDkKOvKLRErhF4mUwi8SKYVfJFIKv0ikFH6RSKnVV69GtrwCjx1s\ntzl1VjPuu+rfIUsrMNTKs1LgDoFDlzn1YBsxVA/9u2dAK1BHfpFIKfwikVL4RSKl8ItESuEXiZTC\nLxIphV8kUurzfyhLHz/YK89WL00Eeu0T6bXyuP/Y5TH/sUvjgX2HziNw+t3Vsr9ppcNvlofq1Xav\n5m/rnSNQTz2TJp0joCO/SKQUfpFIKfwikVL4RSKl8ItESuEXiZTCLxKpePr8WS9h7fSzg336QK88\n1DPu2HDYrR97pje1tviFMXfbthP+iQDlIb/O8YpbN6Y3rW22/+s30dPh1sfn+NuP9aQ/sROd7qao\n+LuGtQXOMQicw9AK8/0zhZ/kXgCDACoAJsysL49BiUjj5XHk/10zO5LD44hIE+lvfpFIZQ2/AdhG\ncgfJTVPdgeQmkgMkBypDJzPuTkTykvVl/6VmdoDkGQC2knzNzJ6efAcz6wfQDwCdi5cXuWqdiEyS\n6chvZgeSz4cBPALgojwGJSKNN+3wk+wmOefDrwFcAWB3XgMTkcbK8rK/F8AjrPVx2wB8z8yeyGVU\n09HAPj4A0Glnh+bbl0f9xx6/+n23fs+q77n1r7Rdk1qb9ajf0C4dO+HWbWjIr48FzgNw+vylrsDY\n5va49YlzF7r1kUXpx7axOe6mqHSGrnPgb9/xgV/PdD2AnM4RmHb4zWwPgN/IZxgi0mxq9YlESuEX\niZTCLxIphV8kUgq/SKQ0pTcRbvWlP0DJnzUbnD76D5/Z4tZ/MnyuWz/28Jmpta4je9xtrafLraPL\nb6eNnznXf3yn1Te43J83+95a/4n9zIoDbv3mZU+n1s5r96dJn9s+y60fqgy79cue/Au3PufF2W69\nGXTkF4mUwi8SKYVfJFIKv0ikFH6RSCn8IpFS+EUiNbP6/Bmm7TK0baBecqb0lgOX5h68yr98Wbu3\nxjaAoarfDz9/48uptZ/93lnutl2z/V
56Z4c/d3Vpzz63XrX0Pv+iwA9lVWB98SWzj7v1vxn44/R9\nz/enMp8c9Z/z0d3z3XrZWR68VejILxIphV8kUgq/SKQUfpFIKfwikVL4RSKl8ItEamb1+T0Z+/jB\nujPf//1V/rWU7+37rlvfM3aGW+9//Aq3bstGUmuLFx1zt+07/W233hW4WMEHE/71ALZtuyC1Frr8\ndfd+v35ol3/+xNnOD3V0kX+dApvvr7Hd3u2WUe3IuIR3E+jILxIphV8kUgq/SKQUfpFIKfwikVL4\nRSKl8ItE6tTp84dkXNbYm3re8Vm/lx5y1ze/4NY//V+/9B+gmn4SgnX6159/bfZ5/mO3+ccHjjsX\nOgBw7uiR9G2H0s9PAABU/ZMvrMdfEGFiYXozvtrm/0JYKfALEyg7lzGoa/tmCB75Sd5P8jDJ3ZNu\nW0hyK8k3ks8LGjtMEclbPS/7vwPgyo/cdguA7Wa2EsD25HsRmUGC4TezpwEc/cjN6wFsTr7eDGBD\nzuMSkQab7ht+vWZ2MPn6XQC9aXckuYnkAMmBypB/LraINE/md/vNzOBMizGzfjPrM7O+cldgNoSI\nNM10w3+I5BIASD77S56KSMuZbvgfA7Ax+XojgEfzGY6INEuwz0/yIQBrAZxGcj+ArwG4HcAPSN4A\nYB+Aaxo5yLqE+qYZrvkf8jvL3nLrQ1W/177oFb/fHew5t6dfY77a5V9/vjrL/xVgoNfOwJx8TDjn\nAVSciyQAQNk/NlnZnxRf7UjfPtznd8uN7eM36RyAYPjN7LqU0uU5j0VEmkin94pESuEXiZTCLxIp\nhV8kUgq/SKTimdKbldPxGq36T+OI+es1h9ppYKD30+a0vEr+/++lcb/dxmG/l1c6OezWbSi9bhP+\n0uTs9KfshqYbV5xWn5WzTdk9FejILxIphV8kUgq/SKQUfpFIKfwikVL4RSKl8ItEamb1+b3ea6BV\nHpqCWbLA1FWnHT4aWG95pOr3+Y+fNdutLzp43K17OBbopQ/7S3Bz0L/0WvWEX7ex9MdnYEouZ/vP\niwX6/N5U6OCU3ADvUu61HWR7/GbQkV8kUgq/SKQUfpFIKfwikVL4RSKl8ItESuEXidTM6vNnkLUv\n6/X5n9/+WXfbdX/ykltf+5WfuvWDI3PdumcscK2BXU+scuudh/0npuO4X5/70HOptVAfP3QtgtB1\nDryfefAaCqETAWZAHz9ER36RSCn8IpFS+EUipfCLRErhF4mUwi8SKYVfJFLR9PlDQucBeH3hxc/5\n176/8/Vr3foHfqsdl17unycw7qwnPTThL9H96d/3lxc/Gdi+GuiHlzedmVp7d3COu+3gKwvd+uyj\n/r4XvJ5+LYPgz/sU6OOHBI/8JO8neZjk7km33UbyAMmdyce6xg5TRPJWz8v+7wC4corb7zSzNcnH\nlnyHJSKNFgy/mT0N4GgTxiIiTZTlDb+bSO5K/ixYkHYnkptIDpAcqAz513sTkeaZbvjvAXAOgDUA\nDgL4RtodzazfzPrMrK/c1T3N3YlI3qYVfjM7ZGYVM6sCuBfARfkOS0QabVrhJ7lk0rdXA9iddl8R\naU3BPj/JhwCsBXAayf0AvgZgLck1qM1q3gvgSw0cYz4yzNcHAn3fwNTvOe+MuvX5r4+79QP/eZZb\n53gltVYaGnG3PXphr1uvtmW7wP2+30yvdZ7jr0fwR59/3q1/rnu/W//Z4DmptacevdDdds7boV8Y\nv5xJ6ByDnPYdDL+ZXTfFzffls3sRKYpO7xWJlMIvEimFXyRSCr9IpBR+kUhFM6U385LKTt2ZUQsg\n3C7jLP/HwLLfhyyV0x+/Cv/y2PN2vOvWMey3Cq3ij23BU86TM8ufLvzTi/1zx35x4yK3fvOyram1\nrav9edT2Tqdfb2Srr0l05BeJlMIvEimFXyRSCr9IpBR+kUgp/CKRUvhFIjWz+vxZLqfcyCm9oV07\nff
h66qwGtm9z/g9vL/vbdrT7+55Iny4MAIQ/HTm0jLZn/q7/c+tHv77CrXfcnT72xy+52932D/f9\ntVvvedstZ9Okcwh05BeJlMIvEimFXyRSCr9IpBR+kUgp/CKRUvhFIjWz+vyerEsqZ5jPT78VjtKE\n/+Cc8E8y8C7NXdt/+vbBbcfTl7EGAIz7fXwL1L0+P8v+OQihcwROnOn/+naV0sc2WPXPb2g7GTi3\nItSLnwHz/XXkF4mUwi8SKYVfJFIKv0ikFH6RSCn8IpFS+EUiVc8S3csBPACgF7Vud7+Z3UVyIYD/\nALACtWW6rzGz9xs31AbL0Jdl1e/jl8b8Pn55xO+1l4YDvfSx9DpHxtxNbWjIr4/4y4uHlE5Pv7b+\n219Y5m7LS/1fp3/53Lfc+nuV7tTaTf/uryrf+Z7/M622ZzwPoAXUc+SfAPBVM1sN4GIAN5JcDeAW\nANvNbCWA7cn3IjJDBMNvZgfN7MXk60EArwJYCmA9gM3J3TYD2NCoQYpI/j7R3/wkVwC4AMDzAHrN\n7GBSehe1PwtEZIaoO/wkewA8DOBmMzs+uWZmhpSz30luIjlAcqAydDLTYEUkP3WFn2Q7asF/0Mx+\nmNx8iOSSpL4EwOGptjWzfjPrM7O+clf6GzAi0lzB8JMkgPsAvGpmd0wqPQZgY/L1RgCP5j88EWmU\neqb0XgLgiwBeIrkzue1WALcD+AHJGwDsA3BNY4aYk0DrZWShf4e2EW9Ob2DfFpjSO+a3+jgcaLcN\nDafv2qkBADr9JbyxcrFbnpg3y60P/d2x1NqfLvuRu21XyW9TPntypVt/8KHL0x/7iP8zqXQEWnmn\nwBkywfCb2TNI//VOf3ZFpKWdAv9/ich0KPwikVL4RSKl8ItESuEXiZTCLxKpU+fS3YFee6gvu+76\nZ936vHJ6v3z/6AJ32x+/scqtz/1vf/sTn3LLGF/s9MPH/H94+1y/l37hp95x62fMGnTrK2anL7O9\nc9Cf0juw5Xy3PvcX/lTpzs70Xn4lNCU38PuiS3eLyIyl8ItESuEXiZTCLxIphV8kUgq/SKQUfpFI\nzaw+v9c7DSyxHerLPvlLf274tWftSK3929Ln3G33925z63su7nHre8dO8x9/LP3y2Mcqne62Zfi9\n8nlt/vUAvvU/v+3W5z+bPt+/3b9qOHro/1DHu/wfqnd5bQusDt7wPn4LnAegI79IpBR+kUgp/CKR\nUvhFIqXwi0RK4ReJlMIvEqmZ1ef3ZOybjv74dLe+uXRlau2+wKXvS4EVtsuBy/KXR/1+t7d9qRI4\nASKg2uY/sYuHQ9e/n14NAKrtft0CY/Pm5MfQxw/RkV8kUgq/SKQUfpFIKfwikVL4RSKl8ItESuEX\niVSwz09yOYAHAPSiNmu+38zuInkbgD8H8F5y11vNbEujBppZ1r6r085uPxnYtT9lHoEp9cGxV51+\neTXY0PaFrl9fmZWh117Kdu384FoNXj2CPn5IPSf5TAD4qpm9SHIOgB0ktya1O83snxs3PBFplGD4\nzewggIPJ14MkXwWwtNEDE5HG+kR/85NcAeACAM8nN91EchfJ+0lOueYUyU0kB0gOVIYCr49FpGnq\nDj/JHgAPA7jZzI4DuAfAOQDWoPbK4BtTbWdm/WbWZ2Z95a7uHIYsInmoK/wk21EL/oNm9kMAMLND\nZlYxsyqAewFc1LhhikjeguEnSQD3AXjVzO6YdPuSSXe7GsDu/IcnIo1Sz7v9lwD4IoCXSO5MbrsV\nwHUk16DWBNsL4EsNGWGzhFo3Tqsv63LODLXTyv4DBK5wnUnGTqH7vDbysTM7BVp5IfW82/8Mpn4q\nWrenLyJBOsNPJFIKv0ikFH6RSCn8IpFS+EUipfCLROrUuXR3o2Xp+4amnmbs0zewzd/aIujFN5KO\n/CKRUvhFIqXwi0RK4ReJlMIvEimFXyRSCr9IpGhZm8yfZGfkewD2
TbrpNABHmjaAT6ZVx9aq4wI0\ntunKc2xnmZm/3nyiqeH/2M7JATPrK2wAjlYdW6uOC9DYpquosellv0ikFH6RSBUd/v6C9+9p1bG1\n6rgAjW26ChlboX/zi0hxij7yi0hBFH6RSBUSfpJXknyd5JskbyliDGlI7iX5EsmdJAcKHsv9JA+T\n3D3ptoUkt5J8I/k85RqJBY3tNpIHkuduJ8l1BY1tOcknSb5C8mWSf5ncXuhz54yrkOet6X/zkywD\n+DmAzwPYD+AFANeZ2StNHUgKknsB9JlZ4SeEkLwMwAkAD5jZ+clt/wTgqJndnvzHucDM/rZFxnYb\ngBNFL9uerCa1ZPKy8gA2APgzFPjcOeO6BgU8b0Uc+S8C8KaZ7TGzMQDfB7C+gHG0PDN7GsDRj9y8\nHsDm5OvNqP3yNF3K2FqCmR00sxeTrwcBfLisfKHPnTOuQhQR/qUA3pn0/X4U+ARMwQBsI7mD5Kai\nBzOFXjM7mHz9LoDeIgczheCy7c30kWXlW+a5m85y93nTG34fd6mZrQFwFYAbk5e3Lclqf7O1Uq+2\nrmXbm2WKZeV/pcjnbrrL3eetiPAfALB80vfLkttagpkdSD4fBvAIWm/p8UMfrpCcfD5c8Hh+pZWW\nbZ9qWXm0wHPXSsvdFxH+FwCsJHk2yQ4A1wJ4rIBxfAzJ7uSNGJDsBnAFWm/p8ccAbEy+3gjg0QLH\n8mtaZdn2tGXlUfBz13LL3ZtZ0z8ArEPtHf+3APx9EWNIGdc5AP43+Xi56LEBeAi1l4HjqL03cgOA\nRQC2A3gDwDYAC1tobN8F8BKAXagFbUlBY7sUtZf0uwDsTD7WFf3cOeMq5HnT6b0ikdIbfiKRUvhF\nIqXwi0RK4ReJlMIvEimFXyRSCr9IpP4f7ErKXNRkDV4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x23534837550>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show one randomly chosen training image, titled with its digit label\n",
    "sample_idx = np.random.choice(train_data.shape[0])\n",
    "sample_image = train_data[sample_idx].reshape(28, 28)\n",
    "sample_digit = np.argmax(train_labels[sample_idx])  # one-hot row -> digit\n",
    "plt.imshow(sample_image)\n",
    "plt.title(sample_digit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Keras model\n",
    "A model similar to the TensorFlow one\n",
    "\n",
    "CONV3 - MAX POOL - CONV3 - MAX POOL - FC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONV3 -> MAX POOL -> CONV3 -> MAX POOL -> FC, declared as a single layer list\n",
    "model = Sequential([\n",
    "    Conv2D(filters=10, kernel_size=3, strides=1, padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    MaxPooling2D(pool_size=2, strides=2, padding='valid'),\n",
    "    Conv2D(filters=10, kernel_size=3, strides=1, padding='same', activation='relu'),\n",
    "    MaxPooling2D(pool_size=2, strides=2, padding='valid'),\n",
    "    Flatten(),\n",
    "    Dense(units=10, activation='softmax'),\n",
    "])\n",
    "\n",
    "# Categorical cross-entropy on the softmax output, optimised with Adam (lr 1e-3)\n",
    "adam = Adam(lr=1e-3)\n",
    "model.compile(\n",
    "    optimizer=adam,\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit the model\n",
    "Run 5 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 48000 samples, validate on 12000 samples\n",
      "Epoch 1/5\n",
      "48000/48000 [==============================] - 11s - loss: 6.1597 - acc: 0.5928 - val_loss: 0.3547 - val_acc: 0.9321\n",
      "Epoch 2/5\n",
      "48000/48000 [==============================] - 9s - loss: 0.2009 - acc: 0.9512 - val_loss: 0.1397 - val_acc: 0.9626\n",
      "Epoch 3/5\n",
      "48000/48000 [==============================] - 9s - loss: 0.1064 - acc: 0.9700 - val_loss: 0.1149 - val_acc: 0.9677\n",
      "Epoch 4/5\n",
      "48000/48000 [==============================] - 9s - loss: 0.0837 - acc: 0.9745 - val_loss: 0.1077 - val_acc: 0.9721\n",
      "Epoch 5/5\n",
      "48000/48000 [==============================] - 9s - loss: 0.0740 - acc: 0.9774 - val_loss: 0.0867 - val_acc: 0.9767\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x2354de7eef0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for 5 epochs in mini-batches of 32, scoring the validation split after each epoch\n",
    "model.fit(\n",
    "    x=train_data,\n",
    "    y=train_labels,\n",
    "    validation_data=(val_data, val_labels),\n",
    "    epochs=5,\n",
    "    batch_size=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run against the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pin the one-hot width to 10 so it always matches the 10-unit softmax output,\n",
    "# even if the highest digit is missing from the loaded test labels.\n",
    "encoded_test_labels = to_categorical(test_labels, num_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 9152/10000 [==========================>...] - ETA: 0s"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.080617253285669724, 0.97719999999999996]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the test set in batches of 64; displays the loss followed by the compiled accuracy metric\n",
    "model.evaluate(x=test_data, y=encoded_test_labels, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Softmax outputs for every test image, computed in batches of 128\n",
    "preds = model.predict(x=test_data, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2]\n",
      "[6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2]\n"
     ]
    }
   ],
   "source": [
    "# Predicted digits (first row) against the true labels (second row) for test samples 100-119\n",
    "sample = slice(100, 120)\n",
    "predicted_digits = np.argmax(preds[sample], axis=1)\n",
    "print(predicted_digits)\n",
    "print(test_labels[sample])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
